#include "lqr.h"
#include "tools.h"

#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <stdio.h>
#include <time.h>

#include <mgl2/mgl.h>

using namespace arma;

// Store a copy of the time-dependent matrix function A(t) (via MatFunction::copyTo).
void lqrSolver::setA_t(MatFunction _A_t)
{
	_A_t.copyTo(this->A_t);
}

// Store a copy of the time-dependent input matrix function B(t) (via MatFunction::copyTo).
void lqrSolver::setB_t(MatFunction _B_t)
{
	_B_t.copyTo(this->B_t);
}

// Store a copy of the terminal-weight matrix function S(t) (via MatFunction::copyTo).
void lqrSolver::setS_t(MatFunction _S_t)
{
	_S_t.copyTo(this->S_t);
}

// Store a copy of the state-weight matrix function Q(t) (via MatFunction::copyTo).
void lqrSolver::setQ_t(MatFunction _Q_t)
{
	_Q_t.copyTo(this->Q_t);
}

// Store a copy of the control-weight matrix function R(t) (via MatFunction::copyTo).
void lqrSolver::setR_t(MatFunction _R_t)
{
	_R_t.copyTo(this->R_t);
}

void lqrSolver::solve()
{
	init();

	switch(type)
	{
		case TIME_INVARIANT_FINITE:		
			for (int i = 1; i < length; ++i)
			{
				XYn[i].set_size(dim*2,dim);
				// notice that the step for P is negatibve!
				// use spim instead XYn[i] = inv_expand(-(0.5*step*H), 5) * (eye<mat>(2*dim,2*dim) - 0.5*step*H) * XYn[i-1];
				XYn[i] = spim(H, -step) * XYn[i-1];
				Pn[i].set_size(dim,dim);
				Pn[i] = XYn[i].rows(dim, 2*dim-1) * inv(XYn[i].rows(0, dim-1));
			}

			for (int i = 1; i < length; ++i)
			{
				// use spim instead x[i] = inv_expand(0.5*step * (A + H.submat(0, dim, dim-1, 2*dim-1)*Pn[length-i-1]), 5)  
				x[i] = spim(A + H.submat(0, dim, dim-1, 2*dim-1)*Pn[length-i-1], step) * x[i-1];
				// u[i] = -inv(R)*trans(B)*Pn[i]*x[i];
				// u[i] = -inv(R)*trans(B)*Pn[length-i-1]*x[i];
			}

			for (int i = 0; i<length; i++)
			{
				u[i] = -inv(R)*trans(B)*Pn[length-i-1]*x[i];
			}
			
			break;

		case TIME_INVARIANT_INFINITE:
		{	// this pair of brace is used for avoiding  crosses initialization
			mat Z, U;
			Z.set_size(dim*2,dim*2);
			Z.submat(0, 0, dim-1, dim-1) = A;
			Z.submat(0, dim, dim-1, 2*dim-1) = -B*inv(R)*trans(B);
			Z.submat(dim, 0, 2*dim-1, dim-1) = -Q;
			Z.submat(dim, dim, 2*dim-1, 2*dim-1) = -trans(A);
			Schur(U, Z);
			//U.print("U=");
			Pn[0] = U.submat(dim, 0, 2*dim-1, dim-1) * inv(U.submat(0, 0, dim-1, dim-1));
			for (int i = 1; i < length; ++i)
			{
				Pn[i] = Pn[0];
			}
			u[0] = -inv(R)*trans(B)*Pn[0]*x[0];

			for (int i = 1; i < length; ++i)
			{
				// use spim instead x[i] = inv_expand(0.5*step * (A + H.submat(0, dim, dim-1, 2*dim-1)*Pn[0]), 5) * (In + 0.5*step*(A+H.submat(0, dim, dim-1, 2*dim-1)*Pn[0])) * x[i-1];
				x[i] = spim(A + H.submat(0, dim, dim-1, 2*dim-1)*Pn[0], step) * x[i-1];
				u[i] = -inv(R)*trans(B)*Pn[0]*x[i];
			}

			break;
		}

		case TIME_VARIANT_FINITE:
			for (int i = 1; i < length; ++i)
			{
				XYn[i].set_size(dim*2,dim);
				//  use spim instead XYn[i] = inv_expand((-0.5*step*Hn[i]), 5) * (eye<mat>(2*dim,2*dim) - 0.5*step*Hn[i]) * XYn[i-1];
				XYn[i] = spim(Hn[i], -step) * XYn[i-1];
				Pn[i].set_size(dim,dim);
				Pn[i] = XYn[i].rows(dim, 2*dim-1) * inv(XYn[i].rows(0, dim-1));
			}	

			for (int i = 1; i < length; ++i)
			{
				// std::cout << "An.size() " << An.size() << "\n";
				//  use spim instead x[i] = inv_expand(0.5*step * (An[i] + Hn[i].submat(0, dim, dim-1, 2*dim-1)*Pn[length-i-1]), 5) * (In + 0.5*step*(An[i]+Hn[i].submat(0, dim, dim-1, 2*dim-1)*Pn[length-i-1])) * x[i-1];
				x[i] = spim(An[i] + Hn[i].submat(0, dim, dim-1, 2*dim-1)*Pn[length-i-1], step) * x[i-1];
				// std::cout << "An.size() " << An.size() << "\n";
				// u[i] = -inv(Rn[i])*trans(Bn[i])*Pn[i]*x[i];
			}

			for (int i = 0; i<length; i++)
			{
				u[i] = -inv(Rn[i])*trans(Bn[i])*Pn[length-i-1]*x[i];
			}

			break;
	}

}


/*
 * Number of steps of size `step` between t0 and t1, rounded to the nearest integer.
 *
 * Throws std::invalid_argument when that count is negative or not finite (t1 < t0,
 * a step of the wrong sign, step == 0, NaN). Callers size their vectors with
 * count + 1 and immediately write sample 0. A count of -1 would leave them empty,
 * and NaN/inf cannot be converted to an integer.
 */
static double lqrStepCount(double t0, double t1, double step)
{
	const double n = round((t1 - t0)/step);
	if (!std::isfinite(n) || n < 0.0)
	{
		throw std::invalid_argument("lqrSolver::init: invalid time grid (need t0 <= t1 and a finite, non-zero step of matching sign)");
	}
	return n;
}

/*
 * Prepare all working storage for solve() according to `type`:
 *   - dim, the number of grid samples `length`, and the per-sample vectors;
 *   - the Hamiltonian H = [A, -B R^{-1} B^T; -Q, -A^T] (time-invariant cases), or
 *     the sampled An/Bn/Qn/Rn/Hn via update() (time-variant case);
 *   - the boundary conditions: x[0] = x0 and, for finite horizons, [X; Y](t1) = [I; S]
 *     with Pn[0] = S.
 * Throws std::invalid_argument (a std::logic_error) if the time grid is invalid.
 */
void lqrSolver::init()
{
	switch( type )
	{
		case TIME_INVARIANT_FINITE:
		case TIME_INVARIANT_INFINITE:
			dim = A.n_rows;

			// Hamiltonian of the LQR problem.
			H.set_size(dim*2, dim*2);
			H.submat(0, 0, dim-1, dim-1) = A;
			H.submat(0, dim, dim-1, 2*dim-1) = -B*inv(R)*trans(B);
			H.submat(dim, 0, 2*dim-1, dim-1) = -Q;
			H.submat(dim, dim, 2*dim-1, 2*dim-1) = -trans(A);

			length = lqrStepCount(t0, t1, step) + 1;
			Pn.resize(length);
			x.resize(length);
			u.resize(length);

			In = eye<mat>(dim,dim);
			x[0] = x0;

			// Only the finite horizon integrates the Hamiltonian system backwards from
			// the terminal condition; the infinite horizon solves the ARE in solve().
			if (type == TIME_INVARIANT_FINITE)
			{
				XYn.resize(length);
				XYn[0].set_size(dim*2,dim);
				XYn[0].rows(0, dim-1) = In;
				XYn[0].rows(dim, 2*dim-1) = S;
				Pn[0] = S;
			}
			break;

		case TIME_VARIANT_FINITE:
			dim = A_t.n_rows;

			length = lqrStepCount(t0, t1, step) + 1;
			An.resize(length);
			Bn.resize(length);
			Qn.resize(length);
			Rn.resize(length);
			Hn.resize(length);
			XYn.resize(length);
			Pn.resize(length);
			x.resize(length);
			u.resize(length);

			// Sample A(t), B(t), Q(t), R(t) and assemble Hn for every grid cell.
			for (int i = 0; i < length; ++i)
			{
				update(i);
			}

			// Terminal condition [X; Y](t1) = [I; S].
			In = eye<mat>(dim,dim);
			XYn[0].set_size(dim*2,dim);
			XYn[0].rows(0, dim-1) = In;
			// NOTE(review): the terminal weight is evaluated at t = 0.0 rather than at t1.
			// This only matters if S_t actually depends on t; confirm the intent.
			S_t.eval(0.0);
			XYn[0].rows(dim, 2*dim-1) = S_t.matVal;
			Pn[0] = S_t.matVal;
			x[0] = x0;
			break;
	}

}

void lqrSolver::update(int n)
{
	//std::cout << "A_t.matVal.size() = " << A_t.matVal.size() << "\n";
	//std::cout << "t0+n*step = " << t0+n*step << "\n";
	//A_t.matVal.print("A_t.matVal");
	//A_t.eval(t0+n*step);
	//A_t.matVal.print("A_t.matVal");
	//An[n].set_size(A_t.n_rows,A_t.n_cols);
	// An[n] = A_t.eval(t0+n*step);
	// Bn[n] = B_t.eval(t0+n*step);
	// Qn[n] = Q_t.eval(t0+n*step);
	// Rn[n] = R_t.eval(t0+n*step);

	/* Using the center point */
	double tn = t1 - (0.5+n)*step;
	An[n] = A_t.eval(tn);
	Bn[n] = B_t.eval(tn);
	Qn[n] = Q_t.eval(tn);
	Rn[n] = R_t.eval(tn);
	Hn[n].set_size(dim*2, dim*2);

	Hn[n].submat(0, 0, dim-1, dim-1) = An[n];

	Hn[n].submat(0, dim, dim-1, 2*dim-1) = -Bn[n]*inv(Rn[n])*trans(Bn[n]);

	Hn[n].submat(dim, 0, 2*dim-1, dim-1) = -Qn[n];

	Hn[n].submat(dim, dim, 2*dim-1, 2*dim-1) = -trans(An[n]);

}

void lqrSolver::saveData(const char* filename) const
{
	using namespace std;

	FILE* file = fopen(filename,"w");
	if(file == NULL){
		cout << "Could not open file " << filename << "!" << endl;
		return;
	}
	
	if (type == TIME_INVARIANT_INFINITE)
	{
		// fout << "## the solution for ARE, t in [" << t0 << "," << t1 << "], with a step " << step << ".\n\n";
		fprintf(file, "The data is stored in the order of P, u, x; P is a symmetric matrix, then only ");
		fprintf(file, "the up triangle is noted in this file. For the mat P = [p11(t) ... p1n(t); ...; pn1(t) ... pnn(t)], ");
		fprintf(file, "it is stored as a vector as in the form of [p11(t) ... p1n(t) ... pn1(t) ... pnn(t)].\n");
		fprintf(file, "t0 = %f, tf = inf, tau = %f, dim(P) = %llu, dim(u) = %llu, dim(x) = %llu\n",t0,step,A.n_rows,R.n_rows,A.n_rows);
		for (int i = 0; i < dim; ++i)
		{
			for (int j = 0; j < dim; ++j)
			{
				// fout << setprecision(12) << Pn[0](i,j) << " ";
				fprintf(file,"%.14lf,", Pn[0](i,j));			
			}
			// fout << endl;
			fprintf(file,"\n");
		}
	}
		
	else
	{
		// fout << "## the solution fo RDE, t in [" << t0 << "," << t1 << "], with a step " << step << ".\n\n";
		fprintf(file, "The data is stored in the order of P, u, x; P is a symmetric matrix, then only ");
		fprintf(file, "the up triangle is noted in this file. For the mat P = [p11(t) ... p1n(t); ...; pn1(t) ... pnn(t)], ");
		fprintf(file, "it is stored as a vector as in the form of [p11(t) ... p1n(t) ... pn1(t) ... pnn(t)].\n");
		fprintf(file, "t0 = %f, tf = %f, tau = %f, dim(P) = %llu, dim(u) = %llu, dim(x) = %llu\n",t0,t1,step,A.n_rows,R.n_rows,A.n_rows);
		for (int i = 0; i < dim; ++i)
		{
			for (int j = i; j < dim; ++j)
			{
				// fout << "## P_" << i << j << "\n\n";
				for (int k = 0; k < length-1; ++k)
				{
					// fout << setprecision(12) << Pn[length-k-1](i,j) << " ";
					fprintf(file, "%.14lf,", Pn[length-k-1](i,j));
				}
				fprintf(file, "%.14lf\n", Pn[0](i,j));	
			}
		}
	}
		
	// fout << "\n## the optimal control u(t), t in [" << t0 << "," << t1 << "], with a step " << step << ".\n\n";
	for (int j = 0; j < u[0].size(); ++j)
	{
		// fout << "## u_" << j << "\n\n";
		for (int i = 0; i < u.size()-1; ++i)
		{
			// fout << u[i](j,0) << " ";
			fprintf(file, "%.14lf,", u[i](j,0));
		}
		// fout << "\n\n";
		fprintf(file, "%.14lf\n", u[u.size()-1](j,0));
	}


	// fout << "\n## the optimal trajactory x(t), t in [" << t0 << "," << t1 << "], with a step " << step << ".\n\n";
	for (int j = 0; j < dim; ++j)
	{
		// fout << "## x_" << j << "\n\n";;
		for (int i = 0; i < x.size()-1; ++i)
		{
			// fout << x[i](j,0) << " ";
			fprintf(file, "%.14lf,", x[i](j,0));
		}
		// fout << "\n\n";
		fprintf(file, "%.14lf\n", x[x.size()-1](j,0));
	}
	
	// fout.close();
	fclose(file);
}

/*
 * Save the solution to "<get_time()>.csv" (see saveData(const char*)).
 * The name is built in a std::string rather than strcat-ed into the fixed
 * 256-byte buffer, which could overflow for a long get_time() result.
 */
void lqrSolver::saveData() const
{
	char currtime[256];
	get_time(currtime);
	const std::string filename = std::string(currtime) + ".csv";
	saveData(filename.c_str());
}

void lqrSolver::draw() const
{
	draw_P();
	draw_u();
	draw_x();
}

/*
 * Plot P(t), u(t) and x(t) to "<prefix>-P(t).eps", "<prefix>-u(t).eps" and
 * "<prefix>-x(t).eps". A null prefix is treated as empty.
 * std::string is used because copying an arbitrary prefix into fixed 256-byte buffers
 * with strcpy/sprintf overflows for long prefixes.
 */
void lqrSolver::draw(const char* filename_prefix) const
{
	const std::string prefix = (filename_prefix != NULL) ? filename_prefix : "";
	// Each temporary std::string lives until the end of its full-expression, so
	// c_str() stays valid for the whole call.
	draw_P((prefix + "-P(t).eps").c_str());
	draw_u((prefix + "-u(t).eps").c_str());
	draw_x((prefix + "-x(t).eps").c_str());
}

/*
 * Plot every state component x_j(t) over t = t0 + i*step and write the frame to
 * `filename` with mglGraph::WriteFrame. Curve j uses lineScheme[j].
 */
void lqrSolver::draw_x(const char* filename) const
{
	mglData mdata(length), t(length);
	for (int i = 0; i < length; ++i)
	{
		t.a[i] = t0 + i*step;
	}

	mglGraph gr;
	gr.Alpha(true);
	gr.Light(true);
	gr.SetRanges(t0, t1, findFloorMin(x), findCeilMax(x));
	gr.Box();
	gr.Axis();
	//gr.Grid("xy", "W");

	char legend[256];
	// std::vector instead of new[]/delete[]: the buffer is released even if plotting throws.
	std::vector<double> samples(length);
	for (int j = 0; j < dim; ++j)
	{
		for (int i = 0; i < length; ++i)
		{
			samples[i] = x[i](j,0);
		}
		mdata.Set(samples.data(), length);
		snprintf(legend, sizeof(legend), "{x_%d}(t)", j);
		// NOTE(review): assumes lineScheme has at least dim entries; confirm.
		gr.Plot(t, mdata, lineScheme[j]);
		gr.AddLegend(legend,  lineScheme[j]);
	}
	gr.Legend(3,"A#");
	gr.WriteFrame(filename);
}

/*
 * Plot x(t) to "<get_time()>  x(t).eps".
 * The name is built in a std::string instead of strcat-ing into the fixed buffer.
 */
void lqrSolver::draw_x() const
{
	char currtime[256];
	get_time(currtime);
	const std::string filename = std::string(currtime) + "  x(t).eps";
	draw_x(filename.c_str());
}


void lqrSolver::draw_u(const char* filename) const
{
	mglData mdata(length), t(length);
	for (int i = 0; i < length; ++i)
	{
		t.a[i] = t0 + i*step;
	}

	mglGraph gr;
	gr.Alpha(true);
	gr.Light(true);
	gr.SetRanges(t0, t1, findFloorMin(u), findCeilMax(u));
	gr.Box();
	gr.Axis();
	char legend[256];
	double *a = new double[length];
	for (int j = 0; j < u[0].size(); ++j)
	{
		for (int i = 0; i < length; ++i)
		{
			a[i] = u[i](j,0);
		}
		mdata.Set(a, length);
		sprintf(legend, "{u_%d}(t)", j);
		gr.Plot(t, mdata, lineScheme[j]);
		gr.AddLegend(legend,  lineScheme[j]);
	}
	gr.Legend(3,"A#");
	gr.WriteFrame(filename);
	delete []a;
}

/*
 * Plot u(t) to "<get_time()>  u(t).eps".
 * The name is built in a std::string instead of strcat-ing into the fixed buffer.
 */
void lqrSolver::draw_u() const
{
	char currtime[256];
	get_time(currtime);
	const std::string filename = std::string(currtime) + "  u(t).eps";
	draw_u(filename.c_str());
}

/*
 * Plot every upper-triangular entry P_ij(t) (i <= j) over t = t0 + k*step and write
 * the frame to `filename` with mglGraph::WriteFrame. Pn is stored backwards from t1,
 * so sample k is read from Pn[length-k-1].
 */
void lqrSolver::draw_P(const char* filename) const
{
	mglData mdata(length), t(length);
	for (int i = 0; i < length; ++i)
	{
		t.a[i] = t0 + i*step;
	}

	mglGraph gr;
	gr.Alpha(true);
	gr.Light(true);
	gr.SetRanges(t0, t1, findFloorMin(Pn), findCeilMax(Pn));
	gr.Box();
	gr.Axis();

	char legend[256];
	// std::vector instead of new[]/delete[]: the buffer is released even if plotting throws.
	std::vector<double> samples(length);
	int count = 0;
	for (int i = 0; i < dim; ++i)
	{
		for (int j = i; j < dim; ++j)
		{
			for (int k = 0; k < length; ++k)
			{
				samples[k] = Pn[length-k-1](i,j);
			}
			// count is incremented before use, so the first P curve uses lineScheme[1].
			// NOTE(review): dim*(dim+1)/2 curves may run past the end of lineScheme; confirm its size.
			count++;
			mdata.Set(samples.data(), length);
			snprintf(legend, sizeof(legend), "{P_{%d%d}}(t)", i, j);
			gr.Plot(t, mdata,  lineScheme[count]);
			gr.AddLegend(legend,  lineScheme[count]);
		}
	}
	gr.Legend(3,"A#");
	gr.WriteFrame(filename);
}

/*
 * Plot P(t) to "<get_time()>  P(t).eps".
 * The name is built in a std::string instead of strcat-ing into the fixed buffer.
 */
void lqrSolver::draw_P() const
{
	char currtime[256];
	get_time(currtime);
	const std::string filename = std::string(currtime) + "  P(t).eps";
	draw_P(filename.c_str());
}

